Added kl_divergence for multivariate normals #1654
Conversation
Force-pushed from 4d999ca to 304cab7.
Force-pushed from 304cab7 to 4e3beae.
making the linter tests happy
Thanks for the feedback @fehiepsi. I have modified the code according to your suggestions and also added a test case where the batch dimensions of the two distributions are not identical (and fixed the linter complaints). If everything is good now, I'd still rebase onto the recent master and merge the commits.
LGTM, thanks @lumip! I have a few small comments on the assertions.
numpyro/distributions/kl.py (Outdated)

```python
Lq_inv = solve_triangular(q_scale_tril, jnp.eye(D), lower=True)
q_half_log_det = jnp.log(jnp.diagonal(q.scale_tril, axis1=-2, axis2=-1)).sum(-1)
assert q_half_log_det.shape == q.batch_shape
```
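(For context: `q_half_log_det` uses the standard Cholesky identity. With $\Sigma_q = L_q L_q^\top$ and $L_q$ lower-triangular,

$$\log\det\Sigma_q = 2\sum_i \log\,(L_q)_{ii},$$

so summing the log of the diagonal of `scale_tril` gives half the log-determinant.)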
nit: this assertion might not be true. In the MultivariateNormal implementation, we avoid unnecessary broadcasting (e.g. we can have a batch of means with a single scale_tril).
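A minimal sketch of the failure mode described above, assuming NumPyro stores `scale_tril` without broadcasting it up to the full batch shape:

```python
import jax.numpy as jnp
import numpyro.distributions as dist

# A batch of 5 means sharing a single scale_tril: batch_shape is (5,),
# but the stored scale_tril keeps its own (unbroadcast) leading shape.
q = dist.MultivariateNormal(loc=jnp.zeros((5, 3)), scale_tril=jnp.eye(3))
print(q.batch_shape)  # (5,)

q_half_log_det = jnp.log(jnp.diagonal(q.scale_tril, axis1=-2, axis2=-1)).sum(-1)
print(q_half_log_det.shape)  # not (5,), so the assertion would fail
```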
Right, I wasn't thinking of that. Removed the assertion and added tests for those cases.
numpyro/distributions/kl.py (Outdated)

```python
    f" {p.event_shape} and {q.event_shape} for p and q, respectively."
)

if p.batch_shape != q.batch_shape:
    min_batch_ndim = min(len(p.batch_shape), len(q.batch_shape))
    if p.batch_shape[-min_batch_ndim:] != q.batch_shape[-min_batch_ndim:]:
```
how about only asserting that p.batch_shape and q.batch_shape can be broadcast?

```python
try:
    result_batch_shape = jnp.broadcast_shapes(p.batch_shape, q.batch_shape)
except ValueError:
    raise ...
```
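For reference, `jnp.broadcast_shapes` follows NumPy's broadcasting rules and raises `ValueError` on incompatible shapes; a quick illustration:

```python
import jax.numpy as jnp

# Singleton and missing leading dimensions broadcast as in NumPy.
print(jnp.broadcast_shapes((5, 1, 3), (2, 3)))  # (5, 2, 3)

# Mismatched non-singleton dimensions raise ValueError.
try:
    jnp.broadcast_shapes((5, 2, 3), (4, 3))
except ValueError as err:
    print("incompatible:", err)
```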
done
numpyro/distributions/kl.py (Outdated)

```python
assert jnp.ndim(q_mean) == 1
assert jnp.ndim(p_scale_tril) == 2
assert jnp.ndim(q_scale_tril) == 2
assert q.mean.shape == q.batch_shape + q.event_shape
```
those assertions are unnecessary.
removed
numpyro/distributions/kl.py (Outdated)

```python
return .5 * (tr + t1 - D - log_det_ratio)
tr = _batch_trace_from_cholesky(Lq_inv @ p.scale_tril)
assert tr.shape == result_batch_shape
```
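Putting the surrounding terms together, the return line is the standard closed-form KL divergence between multivariate normals $p = \mathcal{N}(\mu_p, \Sigma_p)$ and $q = \mathcal{N}(\mu_q, \Sigma_q)$ with event dimension $D$:

$$\mathrm{KL}(p \parallel q) = \frac{1}{2}\left(\operatorname{tr}\!\left(\Sigma_q^{-1}\Sigma_p\right) + (\mu_p - \mu_q)^\top \Sigma_q^{-1} (\mu_p - \mu_q) - D - \log\frac{\det\Sigma_p}{\det\Sigma_q}\right)$$

Here `tr` is $\operatorname{tr}(\Sigma_q^{-1}\Sigma_p) = \lVert L_q^{-1} L_p \rVert_F^2$ (presumably what `_batch_trace_from_cholesky` extracts from `Lq_inv @ p.scale_tril`), `t1` is the Mahalanobis term $\lVert L_q^{-1}(\mu_p - \mu_q)\rVert^2$, and `log_det_ratio` is $\log\det\Sigma_p - \log\det\Sigma_q$.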
this assertion might not be true.
fixed
numpyro/distributions/kl.py (Outdated)

```python
p_mean_flat = jnp.reshape(p.mean, (-1, D))
p_scale_tril_flat = jnp.reshape(p.scale_tril, (-1, D, D))
t1 = jnp.square(Lq_inv @ (p.loc - q.loc)[..., jnp.newaxis]).sum((-2, -1))
assert t1.shape == result_batch_shape
```
this assertion might not be true
fixed
numpyro/distributions/kl.py (Outdated)

```python
).sum(-1)
log_det_ratio = 2 * (p_half_log_det - q_half_log_det)
p_half_log_det = jnp.log(jnp.diagonal(p.scale_tril, axis1=-2, axis2=-1)).sum(-1)
assert p_half_log_det.shape == p.batch_shape
```
this assertion might not be true
removed
test/test_distributions.py (Outdated)

```python
((), ()),
((1,), (1,)),
((2, 3), (2, 3)),
((5, 2, 3), (2, 3)),
```
could you change this to `(5, 1, 3)` and `(2, 3)`?
done
test/test_distributions.py
Outdated
((1,), (1,)), | ||
((2, 3), (2, 3)), | ||
((5, 2, 3), (2, 3)), | ||
((2, 3), (5, 2, 3)), |
nit: maybe `((1, 3), (5, 2, 3))`?
done
Thanks, @lumip!
KL implementation for MultivariateNormal.
EDIT: just saw that there is already a pull request for this from a year ago (#1487). What's the holdup with that?
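Assuming the dispatch registered by this PR, the new rule should be reachable through `kl_divergence` on a pair of `MultivariateNormal`s; a quick Monte Carlo sanity check:

```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.distributions import kl_divergence

# Two small multivariate normals with different means and scales.
p = dist.MultivariateNormal(jnp.zeros(3), scale_tril=jnp.eye(3))
q = dist.MultivariateNormal(jnp.ones(3), scale_tril=2.0 * jnp.eye(3))

analytic = kl_divergence(p, q)

# Monte Carlo estimate: KL(p || q) = E_p[log p(x) - log q(x)].
x = p.sample(jax.random.PRNGKey(0), (100_000,))
mc_estimate = jnp.mean(p.log_prob(x) - q.log_prob(x))

print(analytic, mc_estimate)  # should agree to within MC error
```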